import argparse
import os
import scipy.stats
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve
import functools
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


def args_parser():
    """Parse the command-line options of the ROC plotting script.

    Returns:
        argparse.Namespace with fields num_shadow (int), lira_path (str),
        augmentation (the *string* 'True' or 'False'), model_type (str or
        None), use_dd_aug (bool) and avg_case (bool).
    """
    cli = argparse.ArgumentParser(description='plot ROC curve')
    cli.add_argument('--num_shadow', type=int, required=True,
                     help='the number of shadow models')
    cli.add_argument('--lira_path', type=str, required=True,
                     help='the path of saved results')
    # Kept as a string choice; the caller converts it to a bool.
    cli.add_argument('--augmentation', type=str, default='True',
                     choices=['True', 'False'],
                     help='whether to use data augmentation')
    cli.add_argument('--model_type', type=str,
                     choices=['ConvNet', 'ResNet18', 'ResNet18BN'],
                     help='The model type to use')
    cli.add_argument('--use_dd_aug', action='store_true',
                     help='whether to use transforms in DD')
    # store_true flags default to False.
    cli.add_argument('--avg_case', action='store_true',
                     help='use average case in-out split')
    return cli.parse_args()


def sweep(score, x):
    """
    Build an ROC curve from attack scores and ground-truth membership.

    Args:
        score: array of attack scores where LOWER means "more likely member"
            (the scores are negated before being handed to roc_curve).
        x: boolean array, True for members.

    Returns:
        (fpr, tpr, auc, acc) where acc is the best balanced accuracy
        1 - (FPR + FNR) / 2 over all thresholds.
    """
    false_pos, true_pos, _ = roc_curve(x, -score)
    # 1 - (FPR + (1 - TPR)) / 2 == (TNR + TPR) / 2, maximised over thresholds.
    best_acc = np.max(1 - (false_pos + (1 - true_pos)) / 2)
    area = auc(false_pos, true_pos)
    return false_pos, true_pos, area, best_acc



def generate_ours_online(train_keep, train_scores, test_scores, test_keep, in_size=100000, out_size=100000, fix_variance=False):
    """
    Online LiRA: per example, fit one Gaussian to the scores of shadow models
    that trained on it ("in") and one to those that did not ("out"), then
    score each target model's examples by a log-likelihood ratio.

    Args:
        train_keep: (n_models, n_examples) membership matrix; entries are
            compared against 1 (in) and 0 (out).
        train_scores: (n_models, n_examples, n_aug) shadow-model scores.
        test_scores: iterable of per-target-model score arrays, presumably
            (n_examples, n_aug) so they broadcast against the fitted
            means -- TODO confirm with callers.
        test_keep: iterable of per-target-model membership arrays.
        in_size, out_size: upper bounds on how many in/out scores are used
            per example; they are further capped at the smallest in/out
            count found over all examples.
        fix_variance: if True, use one global std instead of per-example stds.

    Returns:
        (prediction, answers) as flat lists. prediction is the per-example
        mean (over the last axis) of NLL_in - NLL_out, so LOWER values mean
        "more likely member".
    """
    n_examples = train_scores.shape[1]
    # For each example, gather the scores of shadow models that did / did not train on it.
    in_scores = [train_scores[:, j, :][train_keep[:, j] == 1] for j in range(n_examples)]
    out_scores = [train_scores[:, j, :][train_keep[:, j] == 0] for j in range(n_examples)]

    # Truncate to a common length so the per-example lists stack into arrays.
    in_size = min(min(len(s) for s in in_scores), in_size)
    out_size = min(min(len(s) for s in out_scores), out_size)
    in_scores = np.array([s[:in_size] for s in in_scores])
    out_scores = np.array([s[:out_size] for s in out_scores])

    # NOTE: the Gaussian centres use the mean; an np.median() variant also exists.
    mean_in = np.mean(in_scores, 1)
    mean_out = np.mean(out_scores, 1)

    if fix_variance:
        # NOTE(review): both stds are computed from the *in* scores, so the two
        # Gaussians share one global std. std_out may have been meant to use
        # out_scores -- confirm before enabling the fixed-variance variant.
        std_in = np.std(in_scores)
        std_out = np.std(in_scores)
    else:
        std_in = np.std(in_scores, 1)
        std_out = np.std(out_scores, 1)

    prediction, answers = [], []
    for truth, sc in zip(test_keep, test_scores):
        # Negative log-likelihoods; the 1e-30 keeps the scale positive when a std is 0.
        nll_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in + 1e-30)
        nll_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)
        # nll_in < nll_out means the "in" Gaussian explains sc better, i.e.
        # lower values indicate membership (sweep() negates the score).
        prediction.extend((nll_in - nll_out).mean(1))
        answers.extend(truth)

    return prediction, answers
    
def generate_ours_offline(train_keep, train_scores, test_scores, test_keep, in_size=100000, out_size=100000, fix_variance=False):
    """
    Offline LiRA: fit only the "out" Gaussian per example (from shadow models
    that did NOT train on it) and score each target model's examples by their
    log-density under that Gaussian.

    Args:
        train_keep: (n_models, n_examples) membership matrix; entries equal
            to 0 mark "out" models.
        train_scores: (n_models, n_examples, n_aug) shadow-model scores.
        test_scores: iterable of per-target-model score arrays, presumably
            (n_examples, n_aug) -- TODO confirm with callers.
        test_keep: iterable of per-target-model membership arrays.
        in_size: unused; accepted for signature parity with
            generate_ours_online.
        out_size: upper bound on out scores used per example; further capped
            at the smallest out count over all examples.
        fix_variance: if True, use one global std instead of per-example stds.

    Returns:
        (prediction, answers) as flat lists. prediction is the per-example
        mean log-density under the out Gaussian, so LOWER values mean
        "more likely member" (sweep() negates the score).
    """
    n_examples = train_scores.shape[1]
    # Offline mode only needs the models that did not train on each example.
    out_scores = [train_scores[:, j, :][train_keep[:, j] == 0] for j in range(n_examples)]

    # Truncate to a common length so the per-example lists stack into one array.
    out_size = min(min(map(len, out_scores)), out_size)
    out_scores = np.array([s[:out_size] for s in out_scores])

    mean_out = np.median(out_scores, 1)
    std_out = np.std(out_scores) if fix_variance else np.std(out_scores, 1)

    prediction = []
    answers = []
    for ans, sc in zip(test_keep, test_scores):
        # The 1e-30 keeps the scale positive when a std is 0.
        score = scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)
        # Collapse trailing axes so the average runs over everything but the first axis.
        score = score.reshape(score.shape[0], -1)
        prediction.extend(score.mean(1))
        answers.extend(ans)
    return prediction, answers

def generate_global(train_keep, train_scores, test_scores, test_keep):
    """
    Global-threshold baseline: score each example by the negated mean of its
    target-model scores; no shadow-model statistics are used.

    Args:
        train_keep, train_scores: ignored; accepted so the signature matches
            the other attack generators.
        test_scores: iterable of per-target-model score arrays, averaged over
            axis 1.
        test_keep: iterable of per-target-model membership arrays.

    Returns:
        (prediction, answers) as flat lists; LOWER prediction means
        "more likely member" (i.e. a higher raw mean score).
    """
    # Pair each target model's membership with its scores (stops at the shorter).
    paired = list(zip(test_keep, test_scores))
    prediction = [value for _, sc in paired for value in -sc.mean(1)]
    answers = [label for ans, _ in paired for label in ans]
    return prediction, answers


def plot_ROC_curve(canaries_mask, keep, scores, fun, save_dir=None):
    """
    Plot log-log ROC curves of the online LiRA attack and the global-threshold
    attack, restricted to the examples selected by canaries_mask, then save
    and show the figure.

    Args:
        canaries_mask: boolean mask selecting examples along axis 1 of keep
            and scores.
        keep: (n_models, n_examples) membership matrix of all models.
        scores: per-model scores with examples on axis 1 (presumably
            (n_models, n_examples, n_aug) -- TODO confirm).
        fun: any object with a __name__; only used to name the output file
            "fprtpr_<fun.__name__>.png".
        save_dir: directory for the PNG. When None, falls back to the
            module-level ``args.lira_path`` (defined only when this file runs
            as a script), so existing callers behave as before.
    """
    plt.figure(figsize=(4,3))

    # Restrict both attacks to the targeted examples (computed once, reused below).
    canary_keep = keep[:, canaries_mask]
    canary_scores = scores[:, canaries_mask]

    # evaluate LiRA on the points that were targeted by poisoning
    # FIXME: fixed-variance LiRA variant is currently disabled
    # do_plot_all(
    #     functools.partial(generate_ours_online, fix_variance=True),
    #     canary_keep, canary_scores,
    #     "With noisy (LiRA, fixed var)\n",
    # )

    do_plot_all(
        functools.partial(generate_ours_online, fix_variance=False),
        canary_keep, canary_scores,
        "With noisy (LiRA)\n",
    )

    # evaluate the global-threshold attack on the points that were targeted by poisoning
    do_plot_all(
        generate_global,
        canary_keep, canary_scores,
        "With noisy (Global threshold)\n",
    )

    plt.semilogx()
    plt.semilogy()
    plt.xlim(1e-4,1)
    plt.ylim(1e-4,1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.plot([0, 1], [0, 1], ls='--', color='gray')
    plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
    plt.legend(fontsize=8)

    if save_dir is None:
        # Backward-compatible fallback to the script-level parsed CLI args.
        save_dir = args.lira_path
    plt.savefig(os.path.join(save_dir, f"fprtpr_{fun.__name__}.png"))
    plt.show()


def do_plot_all(fn, keep, scores, legend='', ):
    """
    Evaluate an attack with leave-one-model-out cross-validation and add its
    ROC curve to the current matplotlib figure.

    Each model in turn is the target: fn is fit on every other model's
    (keep, scores) and predicts membership for the target's examples. The
    predictions of all folds are pooled into a single ROC curve, and a
    summary line (AUC, accuracy, TPR at 0.1% FPR) is printed.

    Args:
        fn: attack callable fn(train_keep, train_scores, test_scores,
            test_keep) -> (prediction, answers).
        keep: membership array with models on axis 0.
        scores: score array with models on axis 0.
        legend: prefix for the printed summary and the curve label.
    """
    n_models = len(keep)
    pooled_preds, pooled_truth = [], []

    for target in range(n_models):
        # True only at the held-out (target) model.
        is_target = np.arange(n_models) == target
        preds, truth = fn(keep[~is_target], scores[~is_target],
                          scores[is_target], keep[is_target])
        pooled_preds.extend(preds)
        pooled_truth.extend(truth)

    pooled_preds = np.array(pooled_preds)
    print(pooled_preds.shape)
    # Named auc_value so it does not shadow sklearn's auc() function.
    fpr, tpr, auc_value, acc = sweep(pooled_preds, np.array(pooled_truth, dtype=bool))

    # TPR at the last ROC point whose FPR is strictly below 0.1%.
    low = tpr[np.where(fpr < .001)[0][-1]]
    print('Attack %s   AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f' % (legend, auc_value, acc, low))

    plt.plot(fpr, tpr, label=legend + 'auc=%.3f' % auc_value)




if __name__ == "__main__":
    # --augmentation arrives as the string 'True'/'False'; turn it into a bool.
    args = args_parser()
    args.augmentation = args.augmentation == 'True'

    # Directory holding the per-shadow-model score files.
    suffix = f'{args.model_type}' + ('_dd_aug' if args.use_dd_aug else '')
    score_path = os.path.join(args.lira_path, f"scores_{suffix}")

    # Load each shadow model's scores and the indices of the examples it trained on.
    per_model_scores, indices = [], []
    for shadow_id in range(args.num_shadow):
        per_model_scores.append(np.load(os.path.join(score_path, f"score_{shadow_id}.npy")))
        indices.append(np.load(os.path.join(args.lira_path, f"indices/indice_{shadow_id}.npy")))
    scores = np.array(per_model_scores)

    # keep[m][e] == 1 when shadow model m's index file lists example e.
    keep = []
    for member_idx in indices:
        row = np.zeros(scores.shape[1])
        row[member_idx] = 1
        keep.append(row)

    # Examples to evaluate: the canaries (mislabeled examples), or every example in the average case.
    if args.avg_case:
        canaries_mask = np.ones(scores.shape[1], dtype=bool)
    else:
        canaries_mask = np.zeros(scores.shape[1], dtype=bool)
        canaries_mask[np.load(os.path.join(args.lira_path, "canary_indices.npy"))] = True

    print(scores.shape, canaries_mask.shape, keep[0].shape)
    if not args.augmentation:
        # Keep only the first entry of the last axis, preserving that axis.
        scores = scores[:, :, :1]

    plot_ROC_curve(canaries_mask, np.array(keep), scores, do_plot_all)
